/*    Copyright 2010 Tobias Marschall
 *
 *    This file is part of MoSDi.
 *
 *    MoSDi is free software: you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation, either version 3 of the License, or
 *    (at your option) any later version.
 *
 *    MoSDi is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with MoSDi.  If not, see <http://www.gnu.org/licenses/>.
 */

package mosdi.subcommands;

import java.util.ArrayList;
import java.util.List;

import mosdi.discovery.ObjectiveFunction;
import mosdi.discovery.ScoreAndPValue;
import mosdi.discovery.objectives.OccurrenceCountObjective;
import mosdi.discovery.objectives.SequenceCountObjective;
import mosdi.fa.Alphabet;
import mosdi.fa.FiniteMemoryTextModel;
import mosdi.fa.IIDTextModel;
import mosdi.fa.MarkovianTextModel;
import mosdi.index.SuffixTree;
import mosdi.util.BitArray;
import mosdi.util.FileUtils;
import mosdi.util.Log;
import mosdi.util.LogSpace;
import mosdi.util.NamedSequence;
import mosdi.util.SequenceUtils;

/**
 * Subcommand "local-search": starting from one or more given IUPAC motifs,
 * repeatedly evaluates every motif in a neighborhood of the current motif and
 * moves to the best-scoring one, until the current motif is not replaced anymore.
 */
public class LocalSearchSubcommand extends Subcommand {

	/** Maximal number of 'N' spacers put between the motif and a character added at either end. */
	private static final int MAX_BORDER_SPACER = 2;

	/** Returns the usage text: the superclass usage line followed by arguments and options. */
	@Override
	public String usage() {
		return
		super.usage()+" [options] <objective> <fasta-file> <iupac-pattern>\n" +
		"\n" +
		"<objective> = [occ-count|seq-count]\n" +
		"\n" +
		"Options:\n" +
		"  -F: read patterns from file\n" +
		"  -r: simultaneously consider reverse complementary motif\n" +
		"  -l <max-length>: bound maximal length of motif\n" +
		"  -M <text-model-order>: order of text model to be estimated from\n" +
		"                         given sequences.\n" +
		"  -q <q-gram-table-file>: Estimate text model from given q-gram table\n" +
		"                          (default: estimate from <fasta-file>) A q-gram table\n" +
		"                          can be created by using \"mosdi-utils count-qgrams\".";
	}

	/** Returns a one-sentence description of this subcommand. */
	@Override
	public String description() {
		return "Tries to improve given motif using an (iterated) neighborhood search.";
	}

	/** Returns the name of this subcommand, "local-search". */
	@Override
	public String name() {
		return "local-search";
	}

	/**
	 * Runs the neighborhood search for every start pattern and logs the best
	 * motif found for each of them.
	 *
	 * The JVM is terminated with exit status 1 if the objective name is unknown
	 * or if reading the q-gram table or the FASTA file throws an exception.
	 *
	 * @param args command line arguments of this subcommand
	 * @return 0 after all patterns have been processed
	 */
	@Override
	public int run(String[] args) {
		parseOptions(args, 3, "Frl:M:q:");

		// Option dependencies
		exclusiveOptions("q", "M");

		// Mandatory arguments
		String objectiveName = getStringArgument(0);
		String sequenceFile = getStringArgument(1);
		String pattern = getStringArgument(2);

		// Options
		boolean considerReverse = getBooleanOption("r", false);
		int maxLength = getPositiveIntOption("l", Integer.MAX_VALUE);
		int textModelOrder = getNonNegativeIntOption("M", 0);
		String qGramTableFilename = getStringOption("q", null);

		if (!objectiveName.equals("occ-count") && !objectiveName.equals("seq-count")) {
			Log.errorln("Unknown objective: "+objectiveName);
			System.exit(1);
		}

		Alphabet iupacAlphabet = Alphabet.getIupacAlphabet();
		Alphabet dnaAlphabet = Alphabet.getDnaAlphabet();

		// A q-gram table, if given, takes the place of a model estimated from the sequences.
		FiniteMemoryTextModel textModel = null;
		if (qGramTableFilename!=null) {
			try {
				textModel = SequenceUtils.buildTextModelFromQGramFile(qGramTableFilename);
			} catch (Exception e) {
				Log.errorln(e.toString());
				System.exit(1);
			}
		}

		// With -F the third argument names a pattern file, otherwise it is the pattern itself.
		List<String> patternList;
		if (optionSet.has("F")) {
			patternList = FileUtils.readPatternFile(pattern);
		} else {
			patternList = new ArrayList<String>(1);
			patternList.add(pattern);
		}

		List<NamedSequence> namedSequences = null;
		try {
			namedSequences = SequenceUtils.readFastaFile(sequenceFile, dnaAlphabet);
		} catch (Exception e) {
			Log.errorln(e.toString());
			System.exit(1);
		}

		if (textModel==null) {
			if (textModelOrder==0) {
				textModel = new IIDTextModel(dnaAlphabet.size(), SequenceUtils.sequenceList(namedSequences));
			} else {
				textModel = new MarkovianTextModel(textModelOrder, dnaAlphabet.size(), SequenceUtils.sequenceList(namedSequences));
			}
		}

		SuffixTree suffixTree = new SuffixTree(SequenceUtils.concatSequences(namedSequences, considerReverse), dnaAlphabet.size());
		ObjectiveFunction objective = createObjective(objectiveName, textModel, suffixTree, namedSequences, considerReverse);

		// NOTE(review): this outer timer is started but never stopped within this
		// method -- confirm against Log whether that is intended.
		Log.startTimer();
		for (String originalPattern : patternList) {
			Log.startTimer();

			String bestPattern = originalPattern;
			ScoreAndPValue bestScore = new ScoreAndPValue(0, 0.0);
			while (true) {
				String startPattern = bestPattern;
				List<String> candidates = buildNeighborhood(startPattern, iupacAlphabet, maxLength);
				for (String candidate : candidates) {
					ScoreAndPValue score = objective.staticEvaluate(iupacAlphabet.buildIndexArray(candidate), considerReverse, suffixTree, bestScore.getPValue());
					Log.printf(Log.Level.DEBUG, "Evaluated %s: %s\n", candidate, (score==null)?"null":score.toString());
					if (score==null) {
						continue;
					}
					// Strict comparison: among equally scored candidates the first one generated wins.
					if (score.compareTo(bestScore)>0) {
						bestPattern = candidate;
						bestScore = score;
					}
				}
				Log.println(Log.Level.STANDARD, String.format("Patterns examined: %d, current motif: %s %s", candidates.size(), bestPattern, bestScore.toString()));
				// Stop as soon as a round leaves the motif unchanged. Compared by
				// content, not by reference, so this does not rely on object identity.
				if (bestPattern.equals(startPattern)) {
					break;
				}
			}
			Log.stopTimer("Total time for this motif");

			Log.println(Log.Level.STANDARD, "Format: >> pattern #occurrence pvalue");
			Log.println(Log.Level.STANDARD, String.format(">> %s %d %s", bestPattern, bestScore.getScore(), LogSpace.toString(-bestScore.getMinusLogPValue())));
		}
		return 0;
	}

	/**
	 * Creates the objective function selected by name.
	 *
	 * @param objectiveName "occ-count" or "seq-count"
	 * @param textModel text model handed to the objective
	 * @param suffixTree suffix tree from which the required annotation is computed
	 * @param namedSequences input sequences (used by "seq-count" only)
	 * @param considerReverse passed on when preparing the sequences for "seq-count"
	 * @return the objective function
	 * @throws IllegalArgumentException if the name is neither "occ-count" nor "seq-count"
	 */
	private static ObjectiveFunction createObjective(String objectiveName, FiniteMemoryTextModel textModel, SuffixTree suffixTree, List<NamedSequence> namedSequences, boolean considerReverse) {
		if (objectiveName.equals("occ-count")) {
			int[] occurrenceCountAnnotation = suffixTree.calcOccurrenceCountAnnotation();
			return new OccurrenceCountObjective(textModel, occurrenceCountAnnotation, null, -1, false);
		}
		if (objectiveName.equals("seq-count")) {
			int[] sequenceLengths = new int[namedSequences.size()];
			for (int i=0; i<namedSequences.size(); ++i) {
				sequenceLengths[i] = namedSequences.get(i).length();
			}
			List<int[]> l = SequenceUtils.appendMinusOnes(namedSequences, considerReverse);
			BitArray[] sequenceOccurrenceAnnotation = suffixTree.calcSequenceOccurrenceAnnotation(l);
			return new SequenceCountObjective(textModel, sequenceLengths, sequenceOccurrenceAnnotation, false);
		}
		throw new IllegalArgumentException("Unknown objective: "+objectiveName);
	}

	/**
	 * Generates the candidate list for one search round. The order is fixed:
	 * first the pattern itself, then all single-character substitutions
	 * (position by position, alphabet order), then all extensions by one
	 * non-'N' character with 0 to {@link #MAX_BORDER_SPACER} 'N' spacers, where
	 * for each character the right extension precedes the left extension.
	 *
	 * @param pattern current motif
	 * @param iupacAlphabet alphabet the replacement and extension characters are taken from
	 * @param maxLength extensions longer than this are not generated
	 * @return list of candidates, with {@code pattern} at index 0
	 */
	private static List<String> buildNeighborhood(String pattern, Alphabet iupacAlphabet, int maxLength) {
		List<String> candidates = new ArrayList<String>();
		candidates.add(pattern);
		int alphabetSize = iupacAlphabet.size();

		// ... by changing characters
		for (int i=0; i<pattern.length(); ++i) {
			int currentChar = iupacAlphabet.getIndex(pattern.charAt(i));
			for (int c=0; c<alphabetSize; ++c) {
				if (c==currentChar) {
					continue;
				}
				StringBuilder s = new StringBuilder(pattern);
				s.setCharAt(i, iupacAlphabet.get(c));
				candidates.add(s.toString());
			}
		}

		// ... by appending characters at either end
		StringBuilder spacer = new StringBuilder();
		for (int spacers=0; spacers<=MAX_BORDER_SPACER; ++spacers) {
			// More spacers only make the candidate longer, so we can stop here.
			if (pattern.length()+1+spacers>maxLength) {
				break;
			}
			for (int c=0; c<alphabetSize; ++c) {
				char symbol = iupacAlphabet.get(c);
				if (symbol=='N') {
					continue;
				}
				candidates.add(pattern + spacer + symbol);
				candidates.add(symbol + spacer.toString() + pattern);
			}
			spacer.append('N');
		}
		return candidates;
	}

}
